{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Spark Dataframe\n",
    "\n",
    "This notebook shows how to use agents to interact with a `Spark DataFrame` and `Spark Connect`. It is mostly optimized for question answering.\n",
    "\n",
    "**NOTE: this agent calls the Python agent under the hood, which executes LLM generated Python code - this can be bad if the LLM generated Python code is harmful. Use cautiously.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"...input your openai api key here...\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Spark DataFrame` example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "23/05/15 20:33:10 WARN Utils: Your hostname, Mikes-Mac-mini.local resolves to a loopback address: 127.0.0.1; using 192.168.68.115 instead (on interface en1)\n",
      "23/05/15 20:33:10 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "23/05/15 20:33:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|\n",
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|\n",
      "|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|\n",
      "|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|\n",
      "|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|\n",
      "|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|\n",
      "|          6|       0|     3|    Moran, Mr. James|  male|null|    0|    0|          330877| 8.4583| null|       Q|\n",
      "|          7|       0|     1|McCarthy, Mr. Tim...|  male|54.0|    0|    0|           17463|51.8625|  E46|       S|\n",
      "|          8|       0|     3|Palsson, Master. ...|  male| 2.0|    3|    1|          349909| 21.075| null|       S|\n",
      "|          9|       1|     3|Johnson, Mrs. Osc...|female|27.0|    0|    2|          347742|11.1333| null|       S|\n",
      "|         10|       1|     2|Nasser, Mrs. Nich...|female|14.0|    1|    0|          237736|30.0708| null|       C|\n",
      "|         11|       1|     3|Sandstrom, Miss. ...|female| 4.0|    1|    1|         PP 9549|   16.7|   G6|       S|\n",
      "|         12|       1|     1|Bonnell, Miss. El...|female|58.0|    0|    0|          113783|  26.55| C103|       S|\n",
      "|         13|       0|     3|Saundercock, Mr. ...|  male|20.0|    0|    0|       A/5. 2151|   8.05| null|       S|\n",
      "|         14|       0|     3|Andersson, Mr. An...|  male|39.0|    1|    5|          347082| 31.275| null|       S|\n",
      "|         15|       0|     3|Vestrom, Miss. Hu...|female|14.0|    0|    0|          350406| 7.8542| null|       S|\n",
      "|         16|       1|     2|Hewlett, Mrs. (Ma...|female|55.0|    0|    0|          248706|   16.0| null|       S|\n",
      "|         17|       0|     3|Rice, Master. Eugene|  male| 2.0|    4|    1|          382652| 29.125| null|       Q|\n",
      "|         18|       1|     2|Williams, Mr. Cha...|  male|null|    0|    0|          244373|   13.0| null|       S|\n",
      "|         19|       0|     3|Vander Planke, Mr...|female|31.0|    1|    0|          345763|   18.0| null|       S|\n",
      "|         20|       1|     3|Masselmani, Mrs. ...|female|null|    0|    0|            2649|  7.225| null|       C|\n",
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from langchain_experimental.agents.agent_toolkits import create_spark_dataframe_agent\n",
    "from langchain_openai import OpenAI\n",
    "from pyspark.sql import SparkSession\n",
    "\n",
    "spark = SparkSession.builder.getOrCreate()\n",
    "csv_file_path = \"titanic.csv\"\n",
    "df = spark.read.csv(csv_file_path, header=True, inferSchema=True)\n",
    "df.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = create_spark_dataframe_agent(llm=OpenAI(temperature=0), df=df, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mThought: I need to find out how many rows are in the dataframe\n",
      "Action: python_repl_ast\n",
      "Action Input: df.count()\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
      "Final Answer: There are 891 rows in the dataframe.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'There are 891 rows in the dataframe.'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.run(\"how many rows are there?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mThought: I need to find out how many people have more than 3 siblings\n",
      "Action: python_repl_ast\n",
      "Action Input: df.filter(df.SibSp > 3).count()\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m30\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
      "Final Answer: 30 people have more than 3 siblings.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'30 people have more than 3 siblings.'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.run(\"how many people have more than 3 siblings\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mThought: I need to get the average age first\n",
      "Action: python_repl_ast\n",
      "Action Input: df.agg({\"Age\": \"mean\"}).collect()[0][0]\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m29.69911764705882\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now have the average age, I need to get the square root\n",
      "Action: python_repl_ast\n",
      "Action Input: math.sqrt(29.69911764705882)\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3mname 'math' is not defined\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I need to import math first\n",
      "Action: python_repl_ast\n",
      "Action Input: import math\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now have the math library imported, I can get the square root\n",
      "Action: python_repl_ast\n",
      "Action Input: math.sqrt(29.69911764705882)\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m5.449689683556195\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
      "Final Answer: 5.449689683556195\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'5.449689683556195'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.run(\"whats the square root of the average age?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Spark Connect` example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# in apache-spark root directory. (tested here with \"spark-3.4.0-bin-hadoop3 and later\")\n",
    "# To launch Spark with support for Spark Connect sessions, run the start-connect-server.sh script.\n",
    "!./sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:3.4.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "23/05/08 10:06:09 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "\n",
    "# Now that the Spark server is running, we can connect to it remotely using Spark Connect. We do this by\n",
    "# creating a remote Spark session on the client where our application runs. Before we can do that, we need\n",
    "# to make sure to stop the existing regular Spark session because it cannot coexist with the remote\n",
    "# Spark Connect session we are about to create.\n",
    "SparkSession.builder.master(\"local[*]\").getOrCreate().stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The command we used above to launch the server configured Spark to run as localhost:15002.\n",
    "# So now we can create a remote Spark session on the client using the following command.\n",
    "spark = SparkSession.builder.remote(\"sc://localhost:15002\").getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|\n",
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|\n",
      "|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|\n",
      "|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|\n",
      "|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|\n",
      "|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|\n",
      "|          6|       0|     3|    Moran, Mr. James|  male|null|    0|    0|          330877| 8.4583| null|       Q|\n",
      "|          7|       0|     1|McCarthy, Mr. Tim...|  male|54.0|    0|    0|           17463|51.8625|  E46|       S|\n",
      "|          8|       0|     3|Palsson, Master. ...|  male| 2.0|    3|    1|          349909| 21.075| null|       S|\n",
      "|          9|       1|     3|Johnson, Mrs. Osc...|female|27.0|    0|    2|          347742|11.1333| null|       S|\n",
      "|         10|       1|     2|Nasser, Mrs. Nich...|female|14.0|    1|    0|          237736|30.0708| null|       C|\n",
      "|         11|       1|     3|Sandstrom, Miss. ...|female| 4.0|    1|    1|         PP 9549|   16.7|   G6|       S|\n",
      "|         12|       1|     1|Bonnell, Miss. El...|female|58.0|    0|    0|          113783|  26.55| C103|       S|\n",
      "|         13|       0|     3|Saundercock, Mr. ...|  male|20.0|    0|    0|       A/5. 2151|   8.05| null|       S|\n",
      "|         14|       0|     3|Andersson, Mr. An...|  male|39.0|    1|    5|          347082| 31.275| null|       S|\n",
      "|         15|       0|     3|Vestrom, Miss. Hu...|female|14.0|    0|    0|          350406| 7.8542| null|       S|\n",
      "|         16|       1|     2|Hewlett, Mrs. (Ma...|female|55.0|    0|    0|          248706|   16.0| null|       S|\n",
      "|         17|       0|     3|Rice, Master. Eugene|  male| 2.0|    4|    1|          382652| 29.125| null|       Q|\n",
      "|         18|       1|     2|Williams, Mr. Cha...|  male|null|    0|    0|          244373|   13.0| null|       S|\n",
      "|         19|       0|     3|Vander Planke, Mr...|female|31.0|    1|    0|          345763|   18.0| null|       S|\n",
      "|         20|       1|     3|Masselmani, Mrs. ...|female|null|    0|    0|            2649|  7.225| null|       C|\n",
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "csv_file_path = \"titanic.csv\"\n",
    "df = spark.read.csv(csv_file_path, header=True, inferSchema=True)\n",
    "df.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from langchain_experimental.agents import create_spark_dataframe_agent\n",
    "from langchain_openai import OpenAI\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"...input your openai api key here...\"\n",
    "\n",
    "agent = create_spark_dataframe_agent(llm=OpenAI(temperature=0), df=df, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m\n",
      "Thought: I need to find the row with the highest fare\n",
      "Action: python_repl_ast\n",
      "Action Input: df.sort(df.Fare.desc()).first()\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3mRow(PassengerId=259, Survived=1, Pclass=1, Name='Ward, Miss. Anna', Sex='female', Age=35.0, SibSp=0, Parch=0, Ticket='PC 17755', Fare=512.3292, Cabin=None, Embarked='C')\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the name of the person who bought the most expensive ticket\n",
      "Final Answer: Miss. Anna Ward\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'Miss. Anna Ward'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.run(\n",
    "    \"\"\"\n",
    "who bought the most expensive ticket?\n",
    "You can find all supported function types in https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/dataframe\n",
    "\"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark.stop()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
